package main

import (
	"crypto/md5"
	"encoding/hex"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/PuerkitoBio/goquery"
	"github.com/cheggaaa/pb/v3"
	"github.com/spf13/cobra"
)

// DownloadTask is a single unit of work for the download workers: the URL to fetch,
// the file to write it to, and the state the download and retry logic keep for it.
type DownloadTask struct {
	// What to fetch and where to put it.
	URL        string // address requested with HTTP GET
	TargetPath string // local file that receives the response body

	// Transfer state, maintained by downloadFile.
	CurrentSize int64           // bytes of TargetPath already on disk
	TotalSize   int64           // expected final size; 0 until first determined, -1 when the server gives no length
	Completed   bool            // set once the whole file is known to be on disk
	Bar         *pb.ProgressBar // created lazily on the first download attempt

	// Retry bookkeeping.
	Retries    int // failed attempts so far
	MaxRetries int // retries allowed after the first attempt

	// DownloadedURL is shared by all tasks of one run; a URL present in it has
	// already been queued and is not queued again.
	DownloadedURL *sync.Map
}

var (
	// activeTaskCount is the number of download tasks that are queued or still
	// running; main seeds it and the workers update it with sync/atomic.
	activeTaskCount int32
	// taskMutex is handed to every worker together with activeTaskCount.
	taskMutex sync.Mutex
)

// main wires up the "downloader" command. It parses the flags, creates one task per
// seed URL under <output>/<host>/, and runs a pool of `concurrency` workers. The
// workers pull tasks from a shared queue until downloadWithRetry closes it.
func main() {
	var (
		concurrency int
		maxDepth    int
		outputDir   string
		maxRetries  int
		urls        []string
		resources   bool
	)

	rootCmd := &cobra.Command{
		Use:   "downloader",
		Short: "A concurrent website downloader",
		PreRun: func(cmd *cobra.Command, args []string) {
			if len(urls) == 0 {
				fmt.Println("请提供至少一个URL")
				os.Exit(1)
			}
		},
		Run: func(cmd *cobra.Command, args []string) {
			// With no workers nothing ever drains the queue, and wg.Wait would block forever.
			if concurrency < 1 {
				concurrency = 1
			}

			// Create the output directory.
			if err := os.MkdirAll(outputDir, 0755); err != nil {
				fmt.Printf("创建输出目录失败: %v\n", err)
				return
			}

			// Build one task per seed URL. downloadedURLs is shared by every task so a
			// URL is queued at most once per run.
			tasks := make([]*DownloadTask, 0, len(urls))
			var downloadedURLs sync.Map

			for _, u := range urls {
				parsedURL, err := url.Parse(u)
				if err != nil {
					fmt.Printf("无效的URL %s: %v\n", u, err)
					continue
				}

				// Each host gets its own sub-directory.
				urlDir := filepath.Join(outputDir, parsedURL.Hostname())
				if err := os.MkdirAll(urlDir, 0755); err != nil {
					fmt.Printf("创建URL目录失败 %s: %v\n", urlDir, err)
					continue
				}

				// Register the seed before any worker starts. Otherwise a seed listed
				// twice would be written by two workers at once, and a page linking back
				// to a seed would queue it again. The value 0 records the seed's crawl depth.
				if _, seen := downloadedURLs.LoadOrStore(u, 0); seen {
					continue
				}

				// Derive the target file name; directory-like paths become index.html.
				fileName := filepath.Base(parsedURL.Path)
				if fileName == "" || fileName == "." || fileName == ".." || fileName == "/" || fileName == "//" {
					fileName = "index.html"
				}

				tasks = append(tasks, &DownloadTask{
					URL:           u,
					TargetPath:    filepath.Join(urlDir, fileName),
					MaxRetries:    maxRetries,
					DownloadedURL: &downloadedURLs,
				})
			}

			// With nothing queued, no worker would ever finish a task and close the
			// queue, so the workers (and wg.Wait) would block forever.
			if len(tasks) == 0 {
				fmt.Println("没有可下载的有效URL")
				return
			}

			var wg sync.WaitGroup
			// Extra headroom for the link/resource tasks that workers queue.
			taskCh := make(chan *DownloadTask, len(tasks)*10)

			// Count the seeds before any worker can finish one. This keeps the counter
			// from reaching zero (and the queue from being closed) while seeds are unsent.
			atomic.StoreInt32(&activeTaskCount, int32(len(tasks)))

			for i := 0; i < concurrency; i++ {
				wg.Add(1)
				go func() {
					defer wg.Done()
					for task := range taskCh {
						downloadWithRetry(task, maxDepth, resources, taskCh, &activeTaskCount, &taskMutex)
					}
				}()
			}

			for _, task := range tasks {
				taskCh <- task
			}

			// Workers return only after taskCh has been closed and drained. That happens
			// once the last queued task finishes, so nothing is left in the channel here.
			wg.Wait()

			fmt.Println("所有下载任务已完成")
		},
	}

	rootCmd.Flags().IntVarP(&concurrency, "concurrency", "c", 5, "并发下载数")
	rootCmd.Flags().IntVarP(&maxDepth, "depth", "d", 1, "递归下载深度")
	rootCmd.Flags().StringVarP(&outputDir, "output", "o", "downloads", "输出目录")
	rootCmd.Flags().IntVarP(&maxRetries, "retries", "r", 3, "下载失败重试次数")
	rootCmd.Flags().StringSliceVarP(&urls, "urls", "u", []string{}, "要下载的URL列表")
	rootCmd.Flags().BoolVarP(&resources, "resources", "s", true, "是否下载资源文件(图片、CSS、JS等)")

	if err := rootCmd.Execute(); err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
}

// downloadWithRetry runs one queued task on a worker goroutine.
//
// It calls downloadFile until it succeeds or task.Retries exceeds task.MaxRetries,
// pausing 2 seconds between attempts. After a successful download of a file that
// isHTMLFile accepts, and only when maxDepth > 0, it extracts the page's links and
// resources and queues them:
//   - same-domain links, only while the page's crawl depth is below maxDepth;
//   - resources (images, CSS, JS, media), when downloadResources is true.
//
// Crawl depth: every URL queued here is stored in task.DownloadedURL with its depth
// (parent depth + 1) as the value. A URL with no int value there, such as a seed URL,
// is at depth 0.
//
// Termination: *activeCount is the number of tasks queued or running. It is incremented
// before each child is queued and decremented exactly once when this call returns. The
// call that brings it to zero closes taskCh, which ends the workers' range loops.
// mutex is no longer needed for that; it is kept only so the signature stays unchanged.
func downloadWithRetry(task *DownloadTask, maxDepth int, downloadResources bool, taskCh chan<- *DownloadTask, activeCount *int32, mutex *sync.Mutex) {
	defer func() {
		// Decide on the value this decrement produced. Re-loading the counter afterwards
		// let two finishing workers both observe 0 and close the channel twice (panic).
		if atomic.AddInt32(activeCount, -1) == 0 {
			close(taskCh)
		}
	}()

	depth := recordedCrawlDepth(task)

	for {
		err := downloadFile(task)
		if err == nil {
			break
		}
		task.Retries++
		if task.Retries > task.MaxRetries {
			// Out of attempts: report the last error without a pointless final sleep.
			fmt.Printf("下载失败 %s: %v\n", task.URL, err)
			fmt.Printf("下载失败，已达到最大重试次数: %s\n", task.URL)
			return
		}
		fmt.Printf("下载失败 %s: %v (重试 %d/%d)\n", task.URL, err, task.Retries, task.MaxRetries)
		time.Sleep(2 * time.Second) // back off before the next attempt
	}
	fmt.Printf("成功下载: %s\n", task.URL)

	if maxDepth <= 0 || !isHTMLFile(task.TargetPath) {
		return
	}

	links, resources, err := extractLinksAndResources(task.TargetPath, task.URL)
	if err != nil {
		fmt.Printf("提取链接失败 %s: %v\n", task.URL, err)
		return
	}

	// Follow links only while below the depth limit. Resources are the page's own
	// assets, so every parsed page gets them.
	if depth < maxDepth {
		for _, link := range links {
			enqueueDiscoveredURL(task, link, depth+1, "index.html", taskCh, activeCount)
		}
	}
	if downloadResources {
		for _, res := range resources {
			enqueueDiscoveredURL(task, res, depth+1, "resource", taskCh, activeCount)
		}
	}
}

// recordedCrawlDepth returns the int depth stored for task.URL in task.DownloadedURL.
// It returns 0 when no int value is recorded (e.g. seed URLs).
func recordedCrawlDepth(task *DownloadTask) int {
	if v, ok := task.DownloadedURL.Load(task.URL); ok {
		if d, ok := v.(int); ok {
			return d
		}
	}
	return 0
}

// isPlaceholderFileName reports whether name (a filepath.Base result) is not a real
// file name and must be replaced by a default.
func isPlaceholderFileName(name string) bool {
	return name == "" || name == "." || name == ".." || name == "/" || name == string(filepath.Separator)
}

// enqueueDiscoveredURL queues a download task for rawURL unless rawURL was already
// recorded in parent.DownloadedURL. In that case, or when rawURL does not parse, it
// does nothing. The URL is recorded with the given crawl depth. defaultName replaces a
// missing or placeholder file name.
func enqueueDiscoveredURL(parent *DownloadTask, rawURL string, depth int, defaultName string, taskCh chan<- *DownloadTask, activeCount *int32) {
	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		return
	}
	if _, seen := parent.DownloadedURL.LoadOrStore(rawURL, depth); seen {
		return
	}

	fileName := filepath.Base(parsedURL.Path)
	if isPlaceholderFileName(fileName) {
		fileName = defaultName
	}

	// NOTE(review): paths are built relative to the parent page's directory. A page
	// saved in a sub-directory therefore nests its links' directories again (e.g.
	// /a/b.html linking /a/c.html yields .../a/a/c.html), and other-host resources
	// land under the parent's host directory.
	child := &DownloadTask{
		URL:           rawURL,
		TargetPath:    filepath.Join(filepath.Dir(parent.TargetPath), getPathFromURL(parsedURL), fileName),
		MaxRetries:    parent.MaxRetries,
		DownloadedURL: parent.DownloadedURL,
	}

	// Count the child before it is visible to other workers, so the counter cannot
	// reach zero (closing taskCh) while the child is still pending.
	atomic.AddInt32(activeCount, 1)
	select {
	case taskCh <- child:
	default:
		// Buffer full. Every worker may be sitting right here, so a blocking send could
		// deadlock the pool; hand the send to a goroutine instead. taskCh cannot be
		// closed before this send completes because the child is already counted.
		go func() { taskCh <- child }()
	}
}

// downloadFile fetches task.URL into task.TargetPath with a single GET request, resuming
// a partial file when the server allows it.
//
// If a file already exists at TargetPath, its size N is sent as "Range: bytes=N-":
//   - 206 Partial Content: the body is appended to the existing file;
//   - 416 Range Not Satisfiable: the local file is treated as already complete;
//   - any other status below 400 (typically 200, i.e. Range was ignored): the file is
//     rewritten from scratch, because appending a full body would corrupt it;
//   - status >= 400: returned as an error.
//
// On success it updates task.CurrentSize and task.TotalSize (-1 when the server sends no
// length), sets task.Completed, and finishes task.Bar, which is created on the first attempt.
func downloadFile(task *DownloadTask) error {
	req, err := http.NewRequest("GET", task.URL, nil)
	if err != nil {
		return err
	}

	// Present a browser User-Agent.
	req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")

	// Bytes left at TargetPath by an earlier attempt or an earlier run.
	task.CurrentSize = 0
	if fi, statErr := os.Stat(task.TargetPath); statErr == nil {
		task.CurrentSize = fi.Size()
	} else if !os.IsNotExist(statErr) {
		return statErr
	}

	// Ask only for the missing tail. The previous version sent the GET twice, the
	// first time just to read Content-Length, and never consumed that body.
	if task.CurrentSize > 0 {
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-", task.CurrentSize))
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	resume := false
	switch {
	case task.CurrentSize > 0 && resp.StatusCode == http.StatusRequestedRangeNotSatisfiable:
		// No bytes at or beyond our offset: the local copy is complete.
		// NOTE(review): servers also answer 416 when the remote file shrank below the
		// local size; such a stale local file is kept as-is.
		task.TotalSize = task.CurrentSize
		task.Completed = true
		if task.Bar != nil {
			task.Bar.Finish()
		}
		return nil
	case resp.StatusCode >= 400:
		return fmt.Errorf("服务器返回错误: %s", resp.Status)
	case task.CurrentSize > 0 && resp.StatusCode == http.StatusPartialContent:
		resume = true
	default:
		// Full body: start over instead of appending it to the old bytes.
		task.CurrentSize = 0
	}

	// For a 206 reply Content-Length covers only the remaining bytes.
	if resp.ContentLength > 0 {
		task.TotalSize = task.CurrentSize + resp.ContentLength
	} else {
		task.TotalSize = -1 // unknown size
	}

	// Make sure the target directory exists.
	if err := os.MkdirAll(filepath.Dir(task.TargetPath), 0755); err != nil {
		return err
	}

	var file *os.File
	if resume {
		file, err = os.OpenFile(task.TargetPath, os.O_APPEND|os.O_WRONLY, 0644)
	} else {
		file, err = os.Create(task.TargetPath)
	}
	if err != nil {
		return err
	}

	if task.Bar == nil {
		if task.TotalSize > 0 {
			task.Bar = pb.Full.Start64(task.TotalSize)
		} else {
			task.Bar = pb.New(-1)
			task.Bar.Start()
		}
		task.Bar.Set("prefix", filepath.Base(task.URL)+" ")
	} else if task.TotalSize > 0 {
		task.Bar.SetTotal(task.TotalSize)
	}
	// Re-sync on every attempt, since a restart from scratch resets the byte count.
	task.Bar.SetCurrent(task.CurrentSize)

	written, copyErr := io.Copy(file, &ProgressReader{Reader: resp.Body, Bar: task.Bar})
	closeErr := file.Close() // a failed close can mean buffered writes were lost
	task.CurrentSize += written
	if copyErr != nil {
		return copyErr
	}
	if closeErr != nil {
		return closeErr
	}

	// io.Copy returns nil only after the body reached EOF. net/http reports a body cut
	// short of its Content-Length as io.ErrUnexpectedEOF, so the file is complete here,
	// even when the total size was unknown.
	task.Completed = true
	task.Bar.Finish()
	return nil
}

// ProgressReader wraps Reader and advances Bar by the number of bytes each Read
// returns, so copying through it drives the progress bar.
type ProgressReader struct {
	Reader io.Reader
	Bar    *pb.ProgressBar
}

// Compile-time check that ProgressReader can stand in for any io.Reader.
var _ io.Reader = (*ProgressReader)(nil)

// Read forwards to the wrapped reader and reports every byte it produced to Bar.
// The count and error are passed through unchanged.
func (r *ProgressReader) Read(buf []byte) (int, error) {
	count, readErr := r.Reader.Read(buf)
	if count <= 0 {
		return count, readErr
	}
	r.Bar.Add(count)
	return count, readErr
}

// extractLinksAndResources parses the HTML file at filePath and returns the URLs it
// references, each resolved against baseURL with resolveURL:
//   - links: href of every <a> whose host matches baseURL's host (isSameDomain);
//   - resources: <img src>; <link href> whose rel list contains "stylesheet";
//     <script src>; src of <video>, <audio> and <source>; and url(...) references in
//     <style> elements and in inline style="..." attributes.
//
// Values resolveURL rejects (it returns "") are skipped. Both slices are de-duplicated
// with removeDuplicates. Errors from opening or parsing the file, or from parsing
// baseURL, are returned unchanged.
func extractLinksAndResources(filePath, baseURL string) ([]string, []string, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return nil, nil, err
	}
	defer file.Close()

	doc, err := goquery.NewDocumentFromReader(file)
	if err != nil {
		return nil, nil, err
	}

	baseURLObj, err := url.Parse(baseURL)
	if err != nil {
		return nil, nil, err
	}

	links := make([]string, 0)
	resources := make([]string, 0)

	// each resolves attribute attr of every element matched by selector and calls
	// visit with the element and the absolute URL. Missing or rejected values are skipped.
	each := func(selector, attr string, visit func(s *goquery.Selection, absURL string)) {
		doc.Find(selector).Each(func(_ int, s *goquery.Selection) {
			raw, ok := s.Attr(attr)
			if !ok {
				return
			}
			if absURL := resolveURL(baseURLObj, raw); absURL != "" {
				visit(s, absURL)
			}
		})
	}
	addResource := func(_ *goquery.Selection, absURL string) {
		resources = append(resources, absURL)
	}

	// Links: <a href>, same host only.
	each("a", "href", func(_ *goquery.Selection, absURL string) {
		if isSameDomain(baseURLObj, absURL) {
			links = append(links, absURL)
		}
	})

	// Images.
	each("img", "src", addResource)

	// Stylesheets. rel is a space-separated, case-insensitive token list, so
	// "alternate stylesheet" and "Stylesheet" qualify as well.
	each("link", "href", func(s *goquery.Selection, absURL string) {
		rel, _ := s.Attr("rel")
		for _, token := range strings.Fields(rel) {
			if strings.EqualFold(token, "stylesheet") {
				resources = append(resources, absURL)
				return
			}
		}
	})

	// JavaScript.
	each("script", "src", addResource)

	// Media: <video src>/<audio src> as well as their <source> children.
	each("video, audio, source", "src", addResource)

	// url(...) references inside <style> blocks and inline style attributes.
	doc.Find("style").Each(func(_ int, s *goquery.Selection) {
		resources = append(resources, extractCSSURLs(s.Text(), baseURLObj)...)
	})
	doc.Find("[style]").Each(func(_ int, s *goquery.Selection) {
		css, _ := s.Attr("style")
		resources = append(resources, extractCSSURLs(css, baseURLObj)...)
	})

	return removeDuplicates(links), removeDuplicates(resources), nil
}

// cssURLPattern matches CSS url(...) references. It is compiled once at package scope
// rather than on every call. Group 1 is a double-quoted value, group 2 a single-quoted
// value, and group 3 an unquoted value; exactly one of them participates in a match.
// Whitespace inside the parentheses and any letter case of "url" are accepted.
var cssURLPattern = regexp.MustCompile(`(?i)url\(\s*(?:"([^"]*)"|'([^']*)'|([^'"()\s]+))\s*\)`)

// extractCSSURLs returns the absolute form (via resolveURL against baseURL) of every
// url(...) reference in cssText, in order of appearance. Empty references and values
// resolveURL rejects are skipped.
func extractCSSURLs(cssText string, baseURL *url.URL) []string {
	urls := make([]string, 0)
	for _, match := range cssURLPattern.FindAllStringSubmatch(cssText, -1) {
		// Non-participating groups are "", so concatenation yields the one that matched.
		ref := strings.TrimSpace(match[1] + match[2] + match[3])
		if ref == "" {
			continue
		}
		if absURL := resolveURL(baseURL, ref); absURL != "" {
			urls = append(urls, absURL)
		}
	}
	return urls
}

// resolveURL resolves href (trimmed of surrounding whitespace) against base and returns
// the absolute URL as a string. It returns "" when href should not be fetched:
//   - fragment-only references ("#top");
//   - values that fail to parse;
//   - anything whose resolved scheme is not http or https (javascript:, mailto:, data:, tel:, ...).
//
// The fragment is dropped, so "page.html#a" and "page.html#b" resolve to the same URL
// and the page is only downloaded once.
func resolveURL(base *url.URL, href string) string {
	href = strings.TrimSpace(href)
	if strings.HasPrefix(href, "#") {
		return ""
	}

	relURL, err := url.Parse(href)
	if err != nil {
		return ""
	}

	absURL := base.ResolveReference(relURL)
	switch strings.ToLower(absURL.Scheme) {
	case "http", "https":
	default:
		// The HTTP client cannot fetch these; queueing them only wastes retries.
		return ""
	}

	absURL.Fragment = ""
	absURL.RawFragment = ""
	return absURL.String()
}

// isSameDomain reports whether urlStr has the same host name as baseURL. The port is
// ignored, and so is letter case, since host names are case-insensitive. A urlStr that
// fails to parse is never on the same domain.
func isSameDomain(baseURL *url.URL, urlStr string) bool {
	u, err := url.Parse(urlStr)
	if err != nil {
		return false
	}
	return strings.EqualFold(u.Hostname(), baseURL.Hostname())
}

// removeDuplicates returns the entries of urls with repeats dropped, keeping the first
// occurrence of each in its original order. Entries are keyed by the hex MD5 digest of
// the string. The result is never nil.
func removeDuplicates(urls []string) []string {
	deduped := make([]string, 0, len(urls))
	seen := make(map[string]struct{}, len(urls))

	for _, candidate := range urls {
		digest := md5.Sum([]byte(candidate))
		key := hex.EncodeToString(digest[:])
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		deduped = append(deduped, candidate)
	}

	return deduped
}

// isHTMLFile reports whether filePath should be parsed as HTML. That is the case when
// its extension is .html or .htm (case-insensitive), or when it has no extension at
// all, as for pages saved from URLs like /about.
func isHTMLFile(filePath string) bool {
	switch strings.ToLower(filepath.Ext(filePath)) {
	case ".html", ".htm", "":
		return true
	default:
		return false
	}
}

// getPathFromURL returns the directory part of parsedURL's path as a relative,
// OS-specific path, or "" when the path has no directory component. For example,
// "/a/b/c.html" gives "a/b" and "/a/b/" gives "a/b".
//
// Empty and "." segments are dropped. ".." segments are resolved without ever climbing
// above the root: a crafted URL such as "/%2e%2e/%2e%2e/etc/x" decodes to "/../../etc/x",
// and the previous filepath.Dir-based version returned "../../etc" for it, letting
// callers write outside their download directory.
func getPathFromURL(parsedURL *url.URL) string {
	p := parsedURL.Path
	// On Windows a backslash is a separator too, so "..\" could still escape; normalize
	// it to "/" there. Other platforms keep backslashes as ordinary characters.
	if filepath.Separator != '/' {
		p = strings.ReplaceAll(p, string(filepath.Separator), "/")
	}

	segments := strings.Split(p, "/")
	// The last segment is the file name ("" after a trailing slash), not a directory.
	segments = segments[:len(segments)-1]

	dirs := make([]string, 0, len(segments))
	for _, seg := range segments {
		switch seg {
		case "", ".":
			// From the leading "/", doubled slashes, or "./": contributes nothing.
		case "..":
			if len(dirs) > 0 {
				dirs = dirs[:len(dirs)-1]
			}
		default:
			dirs = append(dirs, seg)
		}
	}
	return filepath.Join(dirs...)
}
